{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.10 批量归一化"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.0\n",
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") \n",
    "import d2lzh_pytorch as d2l\n",
    "# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "# Report the library version and the selected device, one per line.\n",
    "print(torch.__version__, device, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.10.2 从零开始实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):\n",
    "    \"\"\"Batch-normalize a 2-D or 4-D input.\n",
    "\n",
    "    Args:\n",
    "        is_training: when truthy, normalize with the statistics of the current\n",
    "            mini-batch and update the moving averages; otherwise normalize\n",
    "            with the moving averages that were passed in.\n",
    "        X: input with 2 dimensions (statistics taken over dim 0) or 4\n",
    "            dimensions (statistics taken over dims 0, 2 and 3).\n",
    "        gamma: scale applied after normalization; must broadcast against X.\n",
    "        beta: shift applied after normalization; must broadcast against X.\n",
    "        moving_mean: running estimate of the mean; must broadcast against X.\n",
    "        moving_var: running estimate of the variance; must broadcast against X.\n",
    "        eps: constant added to the variance before taking the square root.\n",
    "        momentum: fraction of the previous moving statistic kept per update.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (Y, moving_mean, moving_var). In prediction mode the moving\n",
    "        statistics are returned unchanged; in training mode they are new\n",
    "        tensors that carry no autograd history.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: in training mode, if X is neither 2-D nor 4-D.\n",
    "    \"\"\"\n",
    "    if not is_training:\n",
    "        # Prediction mode: use the accumulated moving statistics directly.\n",
    "        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)\n",
    "    else:\n",
    "        assert len(X.shape) in (2, 4)\n",
    "        if len(X.shape) == 2:\n",
    "            # Fully connected case: one mean/variance per feature column.\n",
    "            mean = X.mean(dim=0)\n",
    "            var = ((X - mean) ** 2).mean(dim=0)\n",
    "        else:\n",
    "            # 2-D convolution case: one mean/variance per channel (dim 1).\n",
    "            # keepdim=True keeps the result broadcastable against X.\n",
    "            mean = X.mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "            var = ((X - mean) ** 2).mean(dim=0, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)\n",
    "        # Training mode: standardize with the current batch statistics.\n",
    "        X_hat = (X - mean) / torch.sqrt(var + eps)\n",
    "        # Update the moving statistics outside of autograd. mean and var depend on\n",
    "        # X, so without no_grad() the returned tensors would keep a grad_fn and\n",
    "        # chain every iteration's graph to the previous one.\n",
    "        with torch.no_grad():\n",
    "            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean\n",
    "            moving_var = momentum * moving_var + (1.0 - momentum) * var\n",
    "    Y = gamma * X_hat + beta  # scale and shift\n",
    "    return Y, moving_mean, moving_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BatchNorm(nn.Module):\n",
    "    \"\"\"Batch-normalization layer backed by the batch_norm function.\n",
    "\n",
    "    Args:\n",
    "        num_features: size of the feature (2-D input) or channel (4-D input)\n",
    "            dimension that is normalized.\n",
    "        num_dims: 2 selects the (1, num_features) parameter shape; any other\n",
    "            value selects (1, num_features, 1, 1).\n",
    "        eps: constant added to the variance before taking the square root.\n",
    "        momentum: fraction of the previous moving statistic kept per update.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_features, num_dims, eps=1e-5, momentum=0.9):\n",
    "        super(BatchNorm, self).__init__()\n",
    "        if num_dims == 2:\n",
    "            shape = (1, num_features)\n",
    "        else:\n",
    "            shape = (1, num_features, 1, 1)\n",
    "        self.eps = eps\n",
    "        self.momentum = momentum\n",
    "        # Learnable scale and shift: gamma starts at 1 and beta at 0, so the\n",
    "        # layer initially outputs the plain standardized input.\n",
    "        self.gamma = nn.Parameter(torch.ones(shape))\n",
    "        self.beta = nn.Parameter(torch.zeros(shape))\n",
    "        # Running statistics are not trained, but they are part of the model\n",
    "        # state. Registering them as buffers puts them in state_dict() and lets\n",
    "        # .to()/.cuda() move them. The variance starts at 1 (as in\n",
    "        # nn.BatchNorm*) so evaluating before training does not divide by\n",
    "        # sqrt(eps).\n",
    "        self.register_buffer('moving_mean', torch.zeros(shape))\n",
    "        self.register_buffer('moving_var', torch.ones(shape))\n",
    "\n",
    "    def forward(self, X):\n",
    "        \"\"\"Normalize X; in training mode also refresh the moving statistics.\"\"\"\n",
    "        # Keep the running statistics on the same device as the input.\n",
    "        if self.moving_mean.device != X.device:\n",
    "            self.moving_mean = self.moving_mean.to(X.device)\n",
    "            self.moving_var = self.moving_var.to(X.device)\n",
    "        # self.training is True by default and becomes False after .eval().\n",
    "        Y, moving_mean, moving_var = batch_norm(\n",
    "            self.training, X, self.gamma, self.beta,\n",
    "            self.moving_mean, self.moving_var,\n",
    "            eps=self.eps, momentum=self.momentum)\n",
    "        # Store the statistics without autograd history so the buffers never\n",
    "        # keep a reference to a past iteration's graph.\n",
    "        self.moving_mean = moving_mean.detach()\n",
    "        self.moving_var = moving_var.detach()\n",
    "        return Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.10.2.1 使用批量归一化层的LeNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LeNet with the hand-written BatchNorm layer placed before every sigmoid.\n",
    "net = nn.Sequential(\n",
    "    # Convolutional stage 1\n",
    "    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),\n",
    "    BatchNorm(num_features=6, num_dims=4),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "    # Convolutional stage 2\n",
    "    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),\n",
    "    BatchNorm(num_features=16, num_dims=4),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "    d2l.FlattenLayer(),\n",
    "    # Fully connected stage 1\n",
    "    nn.Linear(in_features=16 * 4 * 4, out_features=120),\n",
    "    BatchNorm(num_features=120, num_dims=2),\n",
    "    nn.Sigmoid(),\n",
    "    # Fully connected stage 2\n",
    "    nn.Linear(in_features=120, out_features=84),\n",
    "    BatchNorm(num_features=84, num_dims=2),\n",
    "    nn.Sigmoid(),\n",
    "    # Output layer: 10 outputs\n",
    "    nn.Linear(in_features=84, out_features=10),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.0039, train acc 0.790, test acc 0.835, time 2.9 sec\n",
      "epoch 2, loss 0.0018, train acc 0.866, test acc 0.821, time 3.2 sec\n",
      "epoch 3, loss 0.0014, train acc 0.879, test acc 0.857, time 2.6 sec\n",
      "epoch 4, loss 0.0013, train acc 0.886, test acc 0.820, time 2.7 sec\n",
      "epoch 5, loss 0.0012, train acc 0.891, test acc 0.859, time 2.8 sec\n"
     ]
    }
   ],
   "source": [
    "# Fashion-MNIST iterators with 256 examples per mini-batch.\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n",
    "\n",
    "# Train with Adam for five epochs on the device chosen in the setup cell.\n",
    "lr = 0.001\n",
    "num_epochs = 5\n",
    "optimizer = optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 1.2537,  1.2284,  1.0100,  1.0171,  0.9809,  1.1870], device='cuda:0'),\n",
       " tensor([ 0.0962,  0.3299, -0.5506,  0.1522, -0.1556,  0.2240], device='cuda:0'))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Flattened scale and shift parameters of the layer at index 1 of net.\n",
    "net[1].gamma.view(-1), net[1].beta.view(-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.10.3 简洁实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same LeNet layout, using the nn.BatchNorm2d / nn.BatchNorm1d layers from torch.\n",
    "net = nn.Sequential(\n",
    "    # Convolutional stage 1\n",
    "    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),\n",
    "    nn.BatchNorm2d(num_features=6),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "    # Convolutional stage 2\n",
    "    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),\n",
    "    nn.BatchNorm2d(num_features=16),\n",
    "    nn.Sigmoid(),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2),\n",
    "    d2l.FlattenLayer(),\n",
    "    # Fully connected stage 1\n",
    "    nn.Linear(in_features=16 * 4 * 4, out_features=120),\n",
    "    nn.BatchNorm1d(num_features=120),\n",
    "    nn.Sigmoid(),\n",
    "    # Fully connected stage 2\n",
    "    nn.Linear(in_features=120, out_features=84),\n",
    "    nn.BatchNorm1d(num_features=84),\n",
    "    nn.Sigmoid(),\n",
    "    # Output layer: 10 outputs\n",
    "    nn.Linear(in_features=84, out_features=10),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.0054, train acc 0.767, test acc 0.795, time 2.0 sec\n",
      "epoch 2, loss 0.0024, train acc 0.851, test acc 0.748, time 2.0 sec\n",
      "epoch 3, loss 0.0017, train acc 0.872, test acc 0.814, time 2.2 sec\n",
      "epoch 4, loss 0.0014, train acc 0.883, test acc 0.818, time 2.1 sec\n",
      "epoch 5, loss 0.0013, train acc 0.889, test acc 0.734, time 1.8 sec\n"
     ]
    }
   ],
   "source": [
    "# Fashion-MNIST iterators with 256 examples per mini-batch.\n",
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)\n",
    "\n",
    "# Train with Adam for five epochs on the device chosen in the setup cell.\n",
    "lr = 0.001\n",
    "num_epochs = 5\n",
    "optimizer = optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
